###########################################################################
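# Purpose: Adjusted (standardized) survival analysis -- pooled logistic
# regression, counterfactual survival curves per arm, bootstrap CIs for contrasts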
###########################################################################

suppressPackageStartupMessages({
  library(dplyr)
  library(speedglm)
  library(Hmisc)     # rcspline.eval
  library(reshape2)  # dcast
  library(boot)
  library(parallel)
})

###############################################################################
# Parameters
###############################################################################
# Outcome model: pooled logistic regression for the per-interval hazard of
# nephropathy. The *_spline and SUand*/combinedand* terms are presumably
# precomputed columns of df.
model_formula <- "nephropathy ~ arm_int + Time + TimeSq + Trial + Trial_spline + SDI_b + Hba1c_cat_b + Age_b + Age_b_spline + Female +eGFR_ff3_b+ eGFR_ff3_b_spline + LDL_ff3_b + LDL_ff3_b_spline + PayerCommerical_b + AnticoagulentRatio_b + NumMeds_b + NumMeds_b_spline + NumLabs_b + NumLabs_b_spline + NumVisits_b + NumVisits_b_spline + Obesity_cat_b + CumulativeBurden_b + CumulativeBurden_b_spline + HeartFailure_b + nafld_nash_h_b + HospShortBinary_b + SteroidUseShortTerm_b + HypercholesterolemiaUse_b + AntiplateletUse_b + CCIBinary_b + SUandeGFR_ff3_b + combinedandeGFR_ff3_b  + SUandLDL_ff3_b + combinedandLDL_ff3_b + I(Age_b*CumulativeBurden_b) + SUandAge_b + combinedandAge_b + SUandCumulativeBurden_b + combinedandCumulativeBurden_b + SUandAgeandCumulativeBurden + combinedandAgeandCumulativeBurden  + SUandTime + combinedandTime + SUandHeartFailure_b + combinedandHeartFailure_b + SUandHba1c_cat_b +combinedandHba1c_cat_b+ ldl_na + sdi_cat + insr_cat"
# Baseline rows (first interval). The standardization step replicates the
# first n of these rows, so it assumes exactly one Time == 0 row per PatId.
baseline <- df %>% dplyr::filter(Time == 0)
# Number of individuals
n <- length(unique(df$PatId))
if (nrow(baseline) != n) {
  warning("Expected one Time == 0 row per PatId, but found ", nrow(baseline),
          " baseline rows for ", n, " individuals.", call. = FALSE)
}
# Duration of follow-up (number of discrete intervals) for standardizing
K1 <- 72

###############################################################################
# 1) Fit pooled logistic regression model 
###############################################################################
# Outcome model fitted on the full person-interval data; row.chunk = 200 is
# passed through speedglm's set.default (see speedglm documentation).
plr_fit <- speedglm::speedglm(
  formula = model_formula,
  data = df,
  family = binomial(),
  set.default = list(row.chunk = 200)
)

###############################################################################
# 2) Standardized survival estimates 
###############################################################################
# Build the counterfactual "copy" dataset for one treatment arm.
#
# Each of the first `n_ids` rows of `baseline_df` is replicated over K1
# intervals (Time = 0, ..., K1 - 1). Treatment is set to `arm_value`, and the
# time and arm-interaction columns used by the outcome model are rebuilt
# under the same names as in the fitting data.
#
# baseline_df: baseline (Time == 0) rows, one per individual; only the first
#              n_ids rows are used.
# arm_value:   arm code; 1 = DPP4, 2 = SU, 3 = combined.
# n_ids:       number of individuals (rows of baseline_df) to replicate.
# K1:          number of follow-up intervals.
#
# Returns an object of baseline_df's class with n_ids * K1 rows, ordered by
# individual and then by Time.
make_arm_copy <- function(baseline_df, arm_value, n_ids, K1) {
  copy <- baseline_df[rep(seq_len(n_ids), each = K1), , drop = FALSE]
  copy$Time <- rep(0:(K1 - 1), times = n_ids)
  copy$TimeSq <- copy$Time^2

  # Returns `x` in rows that belong to `arm` and `other` elsewhere.
  # ifelse() shapes its result like its *test*. A scalar test
  # (`arm_value == arm`) would therefore return only x[1], which would then
  # be recycled down the whole column. Expanding the test keeps the
  # row-wise values.
  for_arm <- function(x, arm, other = 0) {
    ifelse(rep(arm_value == arm, length(x)), x, other)
  }

  # Give arm_int all three arm levels so it is never a single-level factor.
  copy$arm_int <- factor(rep(arm_value, nrow(copy)), levels = c(1, 2, 3))

  # Covariates interacted with arm; each yields a "<prefix><name>" column.
  interaction_sources <- list(
    Time = copy$Time,
    eGFR_ff3_b = copy$eGFR_ff3_b,
    LDL_ff3_b = copy$LDL_ff3_b,
    HeartFailure_b = copy$HeartFailure_b,
    Age_b = copy$Age_b,
    CumulativeBurden_b = copy$CumulativeBurden_b,
    AgeandCumulativeBurden = copy$Age_b * copy$CumulativeBurden_b
  )
  arm_prefixes <- c(SUand = 2, combinedand = 3)

  for (prefix in names(arm_prefixes)) {
    arm <- arm_prefixes[[prefix]]
    for (src in names(interaction_sources)) {
      copy[[paste0(prefix, src)]] <- for_arm(interaction_sources[[src]], arm)
    }
    # HbA1c-category interactions take category 1 (the reference) off-arm.
    copy[[paste0(prefix, "Hba1c_cat_b")]] <- relevel(
      factor(for_arm(copy$Hba1c_cat_b, arm, other = 1), levels = c(1, 2, 3, 4)),
      ref = 1
    )
  }

  copy
}

# Counterfactual datasets: every individual followed for K1 intervals under
# each treatment strategy (arm codes 1 = DPP4, 2 = SU, 3 = combined).
DPP4_copy <- make_arm_copy(baseline, arm_value = 1, n_ids = n, K1 = K1)
SU_copy <- make_arm_copy(baseline, arm_value = 2, n_ids = n, K1 = K1)
combined_copy <- make_arm_copy(baseline, arm_value = 3, n_ids = n, K1 = K1)

# Adds two columns to an arm copy:
#   p = 1 - predicted hazard, the probability of staying event-free in that
#       interval;
#   s = running product of p within each PatId, the predicted survival
#       through the end of that interval.
add_survival <- function(arm_copy, fit) {
  arm_copy$p <- 1 - predict(fit, newdata = arm_copy, type = "response")
  arm_copy %>%
    arrange(PatId, Time) %>%
    group_by(PatId) %>%
    mutate(s = cumprod(p)) %>%
    ungroup()
}

DPP4_copy <- add_survival(DPP4_copy, plr_fit)
SU_copy <- add_survival(SU_copy, plr_fit)
combined_copy <- add_survival(combined_copy, plr_fit)

all_surv <- select(
  bind_rows(DPP4_copy, SU_copy, combined_copy),
  s, arm_int, Time
)

# Standardized survival: the mean of the individual curves per arm and
# interval, with Time moved to the end of the interval.
interval_means <- all_surv %>%
  group_by(Time, arm_int) %>%
  summarise(mean_survival = mean(as.numeric(s), na.rm = TRUE),
            .groups = "drop") %>%
  mutate(Time = Time + 1)

# Everyone is event-free at Time 0.
Time0 <- tibble(Time = 0, arm_int = factor(c(1, 2, 3)), mean_survival = 1)

results <- bind_rows(Time0, interval_means) %>%
  mutate(arm = factor(arm_int, levels = c(1, 2, 3),
                      labels = c("DPP4", "SU", "combined")))

# One row per Time, one survival column per arm.
wideres <- dcast(results, Time ~ arm, value.var = "mean_survival")

# Risk (1 - survival) contrasts for each pair of arms:
# risk difference (RD) and cumulative incidence ratio (CIR).
wideres <- wideres %>%
  mutate(
    SU_DPP4_RD = (1 - SU) - (1 - DPP4),
    SU_DPP4_CIR = (1 - SU) / (1 - DPP4),
    combined_DPP4_RD = (1 - combined) - (1 - DPP4),
    combined_DPP4_CIR = (1 - combined) / (1 - DPP4),
    combined_SU_RD = (1 - combined) - (1 - SU),
    combined_SU_CIR = (1 - combined) / (1 - SU)
  )

###############################################################################
# 3) Bootstrap statistic: risk contrasts recomputed on a patient-level resample
###############################################################################
# Bootstrap statistic for boot::boot(). Each replicate:
#   1. resamples individuals;
#   2. refits the pooled logistic model;
#   3. standardizes survival over the resampled baseline population under
#      each arm;
#   4. returns the risk contrasts at `times_keep`.
#
# data:          vector of PatIds (the resampling unit).
# indices:       positions in `data` drawn for this replicate (from boot()).
# data_frame:    person-interval data with PatId, Time and the model columns.
# model_formula: outcome model formula passed to speedglm().
# K1:            number of follow-up intervals to standardize over.
# times_keep:    interval end-points at which contrasts are returned; keep in
#                sync with the index map used to label the results.
#
# Returns an unnamed numeric vector made of six consecutive blocks, each of
# length(times_keep), in this order:
#   SU vs DPP4 RD, SU vs DPP4 CIR,
#   combined vs DPP4 RD, combined vs DPP4 CIR,
#   combined vs SU RD, combined vs SU CIR.
risk_boot <- function(data, indices, data_frame, model_formula, K1,
                      times_keep = c(12, 24, 36, 48, 60, 72)) {
  # One BootId per draw, because the same PatId can be drawn more than once.
  boot_ids <- data.frame(PatId = data[indices], BootId = seq_along(indices))

  # The join is many-to-many by design: each draw fans out to all of that
  # individual's rows, and repeated draws match the same rows. Declaring the
  # relationship silences dplyr's warning (requires dplyr >= 1.1.0).
  d <- left_join(boot_ids, data_frame, by = "PatId",
                 relationship = "many-to-many")

  # Fit pooled logistic on bootstrap person-time
  fit_b <- speedglm::speedglm(model_formula, data = d, family = binomial(),
                              set.default = list(row.chunk = 200))

  # Baseline rows (one per BootId). Time-derived columns are dropped when
  # present; Time and TimeSq are rebuilt per interval below.
  b <- d %>%
    filter(Time == 0) %>%
    select(-any_of(c("Time", "TimeSq", "TimeCu", "Time_spline",
                     "Time_spline_v1", "Time_spline_v2")))
  n_boot <- nrow(b)

  # Counterfactual copy of `b` under one arm, replicated over K1 intervals.
  make_boot_copy <- function(arm_value) {
    # ifelse() shapes its result like its test. Expand the scalar test so
    # interactions keep row-wise values instead of recycling x[1].
    for_arm <- function(x, arm, other = 0) {
      ifelse(rep(arm_value == arm, length(x)), x, other)
    }

    b[rep(seq_len(n_boot), each = K1), , drop = FALSE] %>%
      mutate(
        Time = rep(0:(K1 - 1), times = n_boot),
        TimeSq = Time^2,   # dropped above; TimeSq is a model term
        PatId = BootId,    # group survival by draw, not by original PatId
        arm_int = factor(arm_value, levels = c(1, 2, 3)),
        SUandTime = for_arm(Time, 2),
        combinedandTime = for_arm(Time, 3),
        SUandeGFR_ff3_b = for_arm(eGFR_ff3_b, 2),
        combinedandeGFR_ff3_b = for_arm(eGFR_ff3_b, 3),
        SUandLDL_ff3_b = for_arm(LDL_ff3_b, 2),
        combinedandLDL_ff3_b = for_arm(LDL_ff3_b, 3),
        SUandHeartFailure_b = for_arm(HeartFailure_b, 2),
        combinedandHeartFailure_b = for_arm(HeartFailure_b, 3),
        SUandAge_b = for_arm(Age_b, 2),
        combinedandAge_b = for_arm(Age_b, 3),
        SUandCumulativeBurden_b = for_arm(CumulativeBurden_b, 2),
        combinedandCumulativeBurden_b = for_arm(CumulativeBurden_b, 3),
        SUandAgeandCumulativeBurden = for_arm(Age_b * CumulativeBurden_b, 2),
        combinedandAgeandCumulativeBurden =
          for_arm(Age_b * CumulativeBurden_b, 3),
        SUandHba1c_cat_b = relevel(
          factor(for_arm(Hba1c_cat_b, 2, other = 1), levels = c(1, 2, 3, 4)),
          ref = 1
        ),
        combinedandHba1c_cat_b = relevel(
          factor(for_arm(Hba1c_cat_b, 3, other = 1), levels = c(1, 2, 3, 4)),
          ref = 1
        )
      )
  }

  # For each interval: p = 1 - predicted hazard, and s = running product of
  # p within a draw, i.e. survival through the end of that interval.
  arm_survival <- lapply(c(1, 2, 3), function(arm_value) {
    copy <- make_boot_copy(arm_value)
    copy$p <- 1 - predict(fit_b, newdata = copy, type = "response")
    copy %>%
      arrange(PatId, Time) %>%
      group_by(PatId) %>%
      mutate(s = cumprod(p)) %>%
      ungroup() %>%
      select(s, arm_int, Time)
  })

  # Standardized survival per arm, at the end of each interval.
  results_boot <- bind_rows(arm_survival) %>%
    group_by(Time, arm_int) %>%
    summarise(mean_survival = mean(as.numeric(s), na.rm = TRUE),
              .groups = "drop") %>%
    mutate(Time = Time + 1)

  # Everyone is event-free at Time 0.
  time0 <- tibble(Time = 0, arm_int = factor(c(1, 2, 3)), mean_survival = 1)

  results_boot <- bind_rows(time0, results_boot) %>%
    mutate(arm = factor(arm_int, levels = c(1, 2, 3),
                        labels = c("DPP4", "SU", "combined")))

  wide_boot <- dcast(results_boot, Time ~ arm, value.var = "mean_survival") %>%
    mutate(
      SU_DPP4_RD = (1 - SU) - (1 - DPP4),
      SU_DPP4_CIR = (1 - SU) / (1 - DPP4),
      combined_DPP4_RD = (1 - combined) - (1 - DPP4),
      combined_DPP4_CIR = (1 - combined) / (1 - DPP4),
      combined_SU_RD = (1 - combined) - (1 - SU),
      combined_SU_CIR = (1 - combined) / (1 - SU)
    )

  keep <- wide_boot$Time %in% times_keep
  contrast_cols <- c("SU_DPP4_RD", "SU_DPP4_CIR",
                     "combined_DPP4_RD", "combined_DPP4_CIR",
                     "combined_SU_RD", "combined_SU_CIR")
  unlist(lapply(contrast_cols, function(col) wide_boot[[col]][keep]),
         use.names = FALSE)
}

###############################################################################
# Run the bootstrap in parallel and pool the per-worker replicates
###############################################################################
num_boot <- 100          # increase for final runs
boot_conf <- 0.95
boot_method <- "perc"    # "perc" or "norm"
# Must match the time grid that risk_boot() returns.
times_keep <- c(12, 24, 36, 48, 60, 72)
stopifnot(num_boot >= 1)

PatIds <- as.vector(unique(df$PatId))

# detectCores() can return NA. Never start more workers than replicates.
detected_cores <- parallel::detectCores()
num_cores <- if (is.na(detected_cores)) 1L else max(1L, detected_cores - 1L)
num_cores <- min(num_cores, num_boot)

# Split the replicates so each worker gets a whole-number R and the total is
# exactly num_boot (num_boot / num_cores is generally not an integer).
reps_per_worker <- diff(round(seq(0, num_boot, length.out = num_cores + 1)))

cl <- makeCluster(num_cores)
risk_results <- tryCatch(
  {
    clusterEvalQ(cl, {
      library(dplyr)
      library(speedglm)
      library(Hmisc)
      library(reshape2)
      library(boot)
    })
    # df is passed to the workers as an argument below. It is not exported,
    # which would ship the data to every worker twice.
    clusterExport(cl, c("model_formula", "K1", "risk_boot"))
    clusterSetRNGStream(cl, iseed = 1)

    # One boot() call per worker; the list of boot objects is pooled next.
    clusterApply(
      cl,
      reps_per_worker,
      function(reps, ids, statistic, data_frame, model_formula, K1) {
        boot::boot(ids, statistic, R = reps, data_frame = data_frame,
                   model_formula = model_formula, K1 = K1)
      },
      ids = PatIds,
      statistic = risk_boot,
      data_frame = df,
      model_formula = model_formula,
      K1 = K1
    )
  },
  # Always release the workers, even if a replicate errors.
  finally = stopCluster(cl)
)

# Merge the per-worker boot objects into a single "boot" object for boot.ci().
# The replicate matrices are stacked row-wise and R is updated. All other
# components (t0, call, statistic, data, sim, ...) are taken from the first
# worker's object, since every worker ran the same statistic on the same data.
#
# boot_list: non-empty list of boot objects with identical t0 / ncol(t).
# Returns a boot object holding all replicates.
fixboot <- function(boot_list) {
  if (length(boot_list) == 0) {
    stop("fixboot(): no bootstrap results to combine.", call. = FALSE)
  }
  combined <- boot_list[[1]]
  combined$t <- do.call(rbind, lapply(boot_list, `[[`, "t"))
  combined$R <- nrow(combined$t)
  class(combined) <- "boot"
  combined
}
risk_results <- fixboot(risk_results)

###############################################################################
# Format bootstrap CIs 
###############################################################################
stat_names <- c("RD", "CIR")
comp_names <- c("SU_vs_DPP4", "combined_vs_DPP4", "combined_vs_SU")

# Label each position of the vector returned by risk_boot().
# The nesting is comparison (outermost), then statistic, then time.
# With six time points there are 36 positions:
# 1..6    SU_vs_DPP4 RD
# 7..12   SU_vs_DPP4 CIR
# 13..18  combined_vs_DPP4 RD
# 19..24  combined_vs_DPP4 CIR
# 25..30  combined_vs_SU RD
# 31..36  combined_vs_SU CIR
#
# times_keep: time points, in the order risk_boot() returns them.
# comps:      comparison labels in risk_boot() order (default comp_names).
# stats:      statistic labels in risk_boot() order (default stat_names).
# Returns a data.frame with columns comparison, statistic, time, index.
make_index_map <- function(times_keep, comps = comp_names, stats = stat_names) {
  # expand.grid() varies its first column fastest, which reproduces the
  # comparison > statistic > time nesting of risk_boot()'s output.
  grid <- expand.grid(
    time = times_keep,
    statistic = stats,
    comparison = comps,
    KEEP.OUT.ATTRS = FALSE,
    stringsAsFactors = FALSE
  )
  data.frame(
    comparison = grid$comparison,
    statistic = grid$statistic,
    time = grid$time,
    index = seq_len(nrow(grid))
  )
}

index_map <- make_index_map(times_keep)

# Returns the (lower, upper) bounds for one element of the pooled bootstrap:
# percentile ("perc") or normal-approximation ("norm"). boot.ci() returns
# NULL when every replicate is identical; report NA bounds then instead of
# failing.
boot_bounds <- function(boot_out, idx, conf, method) {
  ci <- boot.ci(boot_out, conf = conf, type = method, index = idx)
  if (is.null(ci)) {
    return(c(NA_real_, NA_real_))
  }
  bounds <- if (method == "perc") ci$percent[1, 4:5] else ci$normal[1, 2:3]
  unname(bounds)
}

if (!(boot_method %in% c("perc", "norm"))) {
  stop("boot_method must be \"perc\" or \"norm\".", call. = FALSE)
}

# One boot.ci() call per statistic; row 1 = lower, row 2 = upper.
ci_bounds <- vapply(index_map$index, boot_bounds, numeric(2),
                    boot_out = risk_results, conf = boot_conf,
                    method = boot_method)

boot_table <- as_tibble(index_map) %>%
  mutate(
    # RD shown as percentage points; CIR stays a ratio
    scale = ifelse(statistic == "RD", 100, 1),
    estimate = round(scale * risk_results$t0[index], 3),
    lci = round(scale * ci_bounds[1, ], 3),
    uci = round(scale * ci_bounds[2, ], 3),
    result = paste0(estimate, " (", lci, ", ", uci, ")")
  ) %>%
  select(comparison, statistic, time, estimate, lci, uci, result)

print(head(boot_table, 12))
# write.csv(boot_table, "boot_contrasts.csv", row.names = FALSE)